{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "51d693b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "from tvm import relay\n",
    "import numpy as np\n",
    "from tvm.contrib import graph_executor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f4ec3a18",
   "metadata": {},
   "outputs": [],
   "source": [
    "# construct BN\n",
    "def batch_norm(data,\n",
    "                     gamma=None,\n",
    "                     beta=None,\n",
    "                     moving_mean=None,\n",
    "                     moving_var=None,\n",
    "                     **kwargs):\n",
    "    name = kwargs.get(\"name\")\n",
    "    kwargs.pop(\"name\")\n",
    "    if not gamma:\n",
    "        gamma = relay.var(name + \"_gamma\")\n",
    "    if not beta:\n",
    "        beta = relay.var(name + \"_beta\")\n",
    "    if not moving_mean:\n",
    "        moving_mean = relay.var(name + \"_moving_mean\")\n",
    "    if not moving_var:\n",
    "        moving_var = relay.var(name + \"_moving_var\")\n",
    "    return relay.nn.batch_norm(data,\n",
    "                               gamma=gamma,\n",
    "                               beta=beta,\n",
    "                               moving_mean=moving_mean,\n",
    "                               moving_var=moving_var,\n",
    "                               **kwargs)[0]\n",
    "\n",
    "# construct convolution\n",
    "def conv2d(data, weight=None, **kwargs):\n",
    "    name = kwargs.get(\"name\")\n",
    "    kwargs.pop(\"name\")\n",
    "    if not weight:\n",
    "        weight = relay.var(name + \"_weight\")\n",
    "    return relay.nn.conv2d(data, weight, **kwargs)\n",
    "\n",
    "\n",
    "# simplenet: conv+bn+relu\n",
    "def simplenet(data, name, channels, kernel_size=(3, 3), strides=(1, 1),\n",
    "               padding=(1, 1), epsilon=1e-5):\n",
    "    conv = conv2d(\n",
    "        data=data,\n",
    "        channels=channels,\n",
    "        kernel_size=kernel_size,\n",
    "        strides=strides,\n",
    "        padding=padding,\n",
    "        data_layout='NCHW',\n",
    "        name=name+'_conv')\n",
    "    bn = batch_norm(data=conv, epsilon=epsilon, name=name + '_bn')\n",
    "    act = relay.nn.relu(data=bn)\n",
    "    return act\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1bd63673",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_shape = (1, 3, 224, 224)\n",
    "kernel_shape = (32, 3, 3, 3)\n",
    "dtype = \"float32\"\n",
    "data = relay.var(\"data\", shape=data_shape, dtype=dtype)\n",
    "act = simplenet(data, \"graph\", 32, strides=(2, 2))\n",
    "func = relay.Function(relay.analysis.free_vars(act), act)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7575f1f2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "free_var %data: Tensor[(1, 3, 224, 224), float32];\n",
      "free_var %graph_conv_weight;\n",
      "%0 = nn.conv2d(%data, %graph_conv_weight, strides=[2, 2], padding=[1, 1, 1, 1], channels=32, kernel_size=[3, 3]);\n",
      "free_var %graph_bn_gamma;\n",
      "free_var %graph_bn_beta;\n",
      "free_var %graph_bn_moving_mean;\n",
      "free_var %graph_bn_moving_var;\n",
      "%1 = nn.batch_norm(%0, %graph_bn_gamma, %graph_bn_beta, %graph_bn_moving_mean, %graph_bn_moving_var);\n",
      "%2 = %1.0;\n",
      "nn.relu(%2)\n"
     ]
    }
   ],
   "source": [
    "print(act)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c3bc37a6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fn (%data: Tensor[(1, 3, 224, 224), float32], %graph_conv_weight, %graph_bn_gamma, %graph_bn_beta, %graph_bn_moving_mean, %graph_bn_moving_var) {\n",
      "  %0 = nn.conv2d(%data, %graph_conv_weight, strides=[2, 2], padding=[1, 1, 1, 1], channels=32, kernel_size=[3, 3]);\n",
      "  %1 = nn.batch_norm(%0, %graph_bn_gamma, %graph_bn_beta, %graph_bn_moving_mean, %graph_bn_moving_var);\n",
      "  %2 = %1.0;\n",
      "  nn.relu(%2)\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "print(func)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f3fd9010",
   "metadata": {},
   "outputs": [],
   "source": [
    "np_data = np.random.uniform(-1, 1, (1, 3, 224, 224))\n",
    "params = {\n",
    "    \"graph_conv_weight\": tvm.nd.array(np.random.uniform(-1, 1, (32, 3, 3, 3)).astype(dtype)),\n",
    "    \"graph_bn_gamma\": tvm.nd.array(np.random.uniform(-1, 1, (32)).astype(dtype)),\n",
    "    \"graph_bn_beta\": tvm.nd.array(np.random.uniform(-1, 1, (32)).astype(dtype)),\n",
    "    \"graph_bn_moving_mean\": tvm.nd.array(np.random.uniform(-1, 1, (32)).astype(dtype)),\n",
    "    \"graph_bn_moving_var\": tvm.nd.array(np.random.uniform(-1, 1, (32)).astype(dtype)),\n",
    "}\n",
    "with tvm.transform.PassContext(opt_level=10):\n",
    "    lib = relay.build(func, \"llvm\", params=params)\n",
    "\n",
    "\n",
    "dev = tvm.cpu(0)\n",
    "dtype = \"float32\"\n",
    "m = graph_executor.GraphModule(lib[\"default\"](dev))\n",
    "# set inputs\n",
    "m.set_input(\"data\", tvm.nd.array(np_data.astype(dtype)))\n",
    "# execute\n",
    "m.run()\n",
    "# get outputs\n",
    "tvm_output = m.get_output(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "ee15c8fb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<tvm.relay.backend.executor_factory.GraphExecutorFactoryModule object at 0x7f2f4f474460>\n"
     ]
    }
   ],
   "source": [
    "print(lib)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1db8c57",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
